/**
   Support the automatic implementation of test doubles via programmable mocks.
 */
module unit_threaded.mock;

import unit_threaded.from;

/// Identity template: allows the result of a `__traits` expression to be aliased.
alias Identity(alias T) = T;

// True when `__traits(getMember, T, member)` does not compile from this module.
private enum isPrivate(T, string member) = !__traits(compiles, __traits(getMember, T, member));

/**
   Generates, at compile time, the source code of the overriding member
   functions that `Mock!T.MockAbstract` mixes in.

   For every accessible virtual member function of `T` that is neither
   `const` nor `immutable`, and for each of its overloads, the generated
   code contains:
   $(UL
     $(LI aliases `<name>_<i>_parameters` and `<name>_<i>_returnType`;)
     $(LI for non-void functions, an array `<name>_<i>_returnValues`
          holding the programmed return values;)
     $(LI an `override` that appends the function name to `calledFuncs`
          and the stringified arguments to `calledValues`, then returns the
          next programmed value or `.init` when none is left.)
   )

   Returns: the generated code, or `null` when called at run time.
 */
string implMixinStr(T)() {
    import std.array : join;
    import std.format : format;
    import std.traits : functionAttributes, FunctionAttribute, Parameters,
        ReturnType, arity;
    import std.conv : text;

    if (!__ctfe)
        return null;

    string[] lines;

    // Expression that, inside the mixin, names the i-th overload of memberName.
    string getOverload(in string memberName, in int i) {
        return `Identity!(__traits(getOverloads, T, "%s")[%s])`.format(memberName, i);
    }

    foreach (memberName; __traits(allMembers, T)) {

        static if (!isPrivate!(T, memberName)) {

            alias member = Identity!(__traits(getMember, T, memberName));

            static if (__traits(isVirtualMethod, member)) {
                foreach (i, overload; __traits(getOverloads, T, memberName)) {

                    // const and immutable member functions cannot be mocked because
                    // the generated override has to mutate the mock's state.
                    // All per-overload decisions use `overload` (not `member`, which
                    // only describes the first overload of the set).
                    static if (!(functionAttributes!overload & FunctionAttribute.const_)
                            && !(functionAttributes!overload & FunctionAttribute.immutable_)) {

                        enum overloadName = text(memberName, "_", i);

                        enum overloadString = getOverload(memberName, i);
                        lines ~= "private alias %s_parameters = Parameters!(%s);".format(overloadName,
                                overloadString);
                        lines ~= "private alias %s_returnType = ReturnType!(%s);".format(overloadName,
                                overloadString);

                        // nothrow functions get their bookkeeping wrapped in try/catch
                        static if (functionAttributes!overload & FunctionAttribute.nothrow_)
                            enum tryIndent = " ";
                        else
                            enum tryIndent = "";

                        static if (is(ReturnType!overload == void))
                            enum returnDefault = "";
                        else {
                            enum varName = overloadName ~ `_returnValues`;
                            lines ~= `%s_returnType[] %s;`.format(overloadName, varName);
                            lines ~= "";
                            // pop the next programmed value, or fall back to .init
                            enum returnDefault = [
                                    ` if(` ~ varName ~ `.length > 0) {`,
                                    ` auto ret = ` ~ varName ~ `[0];`,
                                    ` ` ~ varName ~ ` = ` ~ varName ~ `[1..$];`,
                                    ` return ret;`,
                                    ` } else`,
                                    ` return %s_returnType.init;`.format(overloadName)
                                ];
                        }

                        lines ~= `override ` ~ overloadName ~ "_returnType " ~ memberName ~ typeAndArgsParens!(
                                Parameters!overload)(overloadName) ~ " "
                            ~ functionAttributesString!overload ~ ` {`;

                        static if (functionAttributes!overload & FunctionAttribute.nothrow_)
                            lines ~= "try {";

                        lines ~= tryIndent ~ ` calledFuncs ~= "` ~ memberName ~ `";`;
                        lines ~= tryIndent ~ ` calledValues ~= tuple` ~ argNamesParens(
                                arity!overload) ~ `.to!string;`;

                        static if (functionAttributes!overload & FunctionAttribute.nothrow_)
                            lines ~= " } catch(Exception) {}";

                        lines ~= returnDefault;

                        lines ~= `}`;
                        lines ~= "";
                    }
                }
            }
        }
    }

    return lines.join("\n");
}

// Returns "(arg0, arg1, ...)" with N argument names; null at run time.
private string argNamesParens(int N) @safe pure {
    if (!__ctfe)
        return null;
    return "(" ~ argNames(N) ~ ")";
}

// Returns "arg0, arg1, ..." with N argument names; null at run time.
private string argNames(int N) @safe pure {
    import std.range;
    import std.algorithm;
    import std.conv;

    if (!__ctfe)
        return null;
    return iota(N).map!(a => "arg" ~ a.to!string).join(", ");
}

// Returns a parenthesised parameter list with one entry per type in T, each
// of the form "<prefix>_parameters[i] argi"; null at run time.
private string typeAndArgsParens(T...)(string prefix) {
    import std.array;
    import std.conv;
    import std.format : format;

    if (!__ctfe)
        return null;

    string[] parts;

    foreach (i, t; T)
        parts ~= "%s_parameters[%s] arg%s".format(prefix, i, i);
    return "(" ~ parts.join(", ") ~ ")";
}

// Returns the space-separated attributes of F that the generated override
// repeats (pure, nothrow, @trusted, @safe, @nogc, @system, shared);
// null at run time.
private string functionAttributesString(alias F)() {
    import std.traits : functionAttributes, FunctionAttribute;
    import std.array : join;

    if (!__ctfe)
        return null;

    string[] parts;

    const attrs = functionAttributes!F;

    if (attrs & FunctionAttribute.pure_)
        parts ~= "pure";
    if (attrs & FunctionAttribute.nothrow_)
        parts ~= "nothrow";
    if (attrs & FunctionAttribute.trusted)
        parts ~= "@trusted";
    if (attrs & FunctionAttribute.safe)
        parts ~= "@safe";
    if (attrs & FunctionAttribute.nogc)
        parts ~= "@nogc";
    if (attrs & FunctionAttribute.system)
        parts ~= "@system";
    // const and immutable can't be done since the mock needs
    // to alter state
    // if(attrs & FunctionAttribute.const_) parts ~= "const";
    // if(attrs & FunctionAttribute.immutable_) parts ~= "immutable";
    if (attrs & FunctionAttribute.shared_)
        parts ~= "shared";

    return parts.join(" ");
}

/**
   State and operations shared by all mock implementations: the recorded
   calls, the expected calls, and the verification comparing the two.
 */
mixin template MockImplCommon() {
    bool _verified;
    string[] expectedFuncs;
    string[] calledFuncs;
    string[] expectedValues;
    string[] calledValues;

    /**
       Registers the expectation that `funcName` gets called. When `values`
       are given, their stringified tuple is stored as the expected
       arguments; otherwise an empty string is stored, which `verify`
       treats as "any arguments".
     */
    void expect(string funcName, V...)(auto ref V values) {
        import std.conv : to;
        import std.typecons : tuple;

        expectedFuncs ~= funcName;
        static if (V.length > 0)
            expectedValues ~= tuple(values).to!string;
        else
            expectedValues ~= "";
    }

    /**
       Registers an expectation and verifies immediately. `_verified` is
       reset afterwards so that further expectations can be checked.
     */
    void expectCalled(string func, string file = __FILE__, size_t line = __LINE__, V...)(
            auto ref V values) {
        expect!func(values);
        verify(file, line);
        _verified = false;
    }

    /**
       Compares the expected calls with the recorded ones, position by
       position. Reports a failure via `fail` when the mock was already
       verified, when an expected call is missing or when a different
       function was called; throws `UnitTestException` when the recorded
       arguments differ from non-empty expected ones.
     */
    void verify(string file = __FILE__, size_t line = __LINE__) @safe pure {
        import std.range : repeat, take, join;
        import std.conv : to;
        import unit_threaded.should : fail, UnitTestException;

        if (_verified)
            fail("Mock already _verified", file, line);

        _verified = true;

        foreach (i, expectedFunc; expectedFuncs) {

            if (i >= calledFuncs.length)
                fail("Expected nth " ~ i.to!string ~ " call to " ~ expectedFunc ~ " did not happen",
                        file, line);

            if (expectedFunc != calledFuncs[i])
                fail("Expected nth " ~ i.to!string ~ " call to " ~ expectedFunc
                        ~ " but got " ~ calledFuncs[i] ~ " instead", file, line);

            // an empty expected value means the arguments are not checked
            if (expectedValues[i] != calledValues[i] && expectedValues[i] != "")
                throw new UnitTestException([expectedFunc ~ " was called with unexpected " ~ calledValues[i],
                        " ".repeat.take(expectedFunc.length + 4)
                        .join ~ "instead of the expected " ~ expectedValues[i]], file, line);
        }
    }
}

private enum isString(alias T) = is(typeof(T) == string);

/**
   A mock object that conforms to an interface/class.
 */
struct Mock(T) {

    MockAbstract _impl;
    alias _impl this;

    /// The generated subclass/implementation of `T` that records calls.
    class MockAbstract : T {
        import std.conv : to;
        import std.traits : Parameters, ReturnType;
        import std.typecons : tuple;

        //pragma(msg, "\nimplMixinStr for ", T, "\n\n", implMixinStr!T, "\n\n");
        mixin(implMixinStr!T);
        mixin MockImplCommon;
    }

    /// Creates the implementation object. The `int` argument is ignored; it
    /// only exists because structs cannot declare a default constructor.
    this(int /* force constructor*/ ) {
        _impl = new MockAbstract;
    }

    /// Verifies the expectations on destruction unless that already happened.
    ~this() pure @safe {
        // A default-initialised Mock has no implementation object; there is
        // nothing to verify and `_verified` must not be dereferenced.
        if (_impl !is null && !_verified)
            verify;
    }

    /// Set the returnValue of a function to certain values.
    void returnValue(string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        return returnValue!(0, funcName)(values);
    }

    /**
       This version takes overloads into account. i is the overload
       index. e.g.:
       ---------
       interface Interface { void foo(int); void foo(string); }
       auto m = mock!Interface;
       m.returnValue!(0, "foo"); // int overload
       m.returnValue!(1, "foo"); // string overload
       ---------
     */
    void returnValue(int i, string funcName, V...)(V values) {
        assertFunctionIsVirtual!funcName;
        import std.conv : text;

        // name of the array generated by implMixinStr for this overload
        enum varName = funcName ~ text(`_`, i, `_returnValues`);
        foreach (v; values)
            mixin(varName ~ ` ~= v;`);
    }

    // Compile-time check that funcName names a virtual member function of T.
    private static void assertFunctionIsVirtual(string funcName)() {
        alias member = Identity!(__traits(getMember, T, funcName));

        static assert(__traits(isVirtualMethod, member),
                "Cannot use returnValue on '" ~ funcName ~ "'");
    }
}

// Returns one "import <module>;" line for module_ and for each extra module;
// null at run time.
private string importsString(string module_, string[] Modules...) {
    if (!__ctfe)
        return null;

    auto ret = `import ` ~ module_ ~ ";\n";
    foreach (extraModule; Modules) {
        ret ~= `import ` ~ extraModule ~ ";\n";
    }
    return ret;
}

/// Helper function for creating a Mock object.
auto mock(T)() {
    return Mock!T(0);
}

///
@("mock interface positive test no params")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int fun(Foo f) {
        return 2 * f.foo(5, "foobar");
    }

    auto m = mock!Foo;
    m.expect!"foo";
    fun(m);
}

///
@("mock interface positive test with params")
@safe pure unittest {
    import unit_threaded.asserts;

    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int fun(Foo f) {
        return 2 * f.foo(5, "foobar");
    }

    auto m = mock!Foo;
    m.expect!"foo"(5, "foobar");
    fun(m);
}

///
@("interface expectCalled")
@safe pure unittest {
    interface Foo {
        int foo(int, string) @safe pure;
        void bar() @safe pure;
    }

    int fun(Foo f) {
        return 2 * f.foo(5, "foobar");
    }

    auto m = mock!Foo;
    fun(m);
    m.expectCalled!"foo"(5, "foobar");
}

///
@("interface return value")
@safe pure unittest {

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int fun(Foo f) {
        return f.timesN(3) * 2;
    }

    auto m = mock!Foo;
    m.returnValue!"timesN"(42);
    immutable res = fun(m);
    assert(res == 84);
}

///
@("interface return values")
@safe pure unittest {

    interface Foo {
        int timesN(int i) @safe pure;
    }

    int fun(Foo f) {
        return f.timesN(3) * 2;
    }

    auto m = mock!Foo;
    m.returnValue!"timesN"(42, 12);
    assert(fun(m) == 84);
    assert(fun(m) == 24);
    assert(fun(m) == 0);
}

/**
   Compile-time association of a function name with the values that
   function should return, for use with the template-argument version of
   `mockStruct`.
 */
struct ReturnValues(string function_, T...)
        if (from!"std.meta".allSatisfy!(isValue, T)) {
    alias funcName = function_;
    alias Values = T;

    /// The values as an array whose element type is that of the first value.
    static auto values() {
        typeof(T[0])[] ret;
        foreach (val; T) {
            ret ~= val;
        }
        return ret;
    }
}

/// True when `T` is an instantiation of `ReturnValues`.
enum isReturnValue(alias T) = is(T : ReturnValues!U, U...);
/// True when `T` has a type, i.e. is a value rather than a type.
enum isValue(alias T) = is(typeof(T));

/**
   Version of mockStruct that accepts 0 or more values of the same
   type. Whatever function is called on it, these values will
   be returned one by one. The limitation is that if more than one
   function is called on the mock, they all return the same type.
 */
auto mockStruct(T...)(auto ref T returns) {

    struct Mock {

        // held by pointer so that copies of the mock share the recorded state
        MockImpl* _impl;
        alias _impl this;

        static struct MockImpl {

            static if (T.length > 0) {
                alias FirstType = typeof(returns[0]);
                private FirstType[] _returnValues;
            }

            mixin MockImplCommon;

            /// Records the call and returns the next value, or `.init`
            /// once the values have run out.
            auto opDispatch(string funcName, V...)(auto ref V values) {

                import std.conv : to;
                import std.typecons : tuple;

                calledFuncs ~= funcName;
                calledValues ~= tuple(values).to!string;

                static if (T.length > 0) {

                    if (_returnValues.length == 0)
                        return typeof(_returnValues[0]).init;
                    auto ret = _returnValues[0];
                    _returnValues = _returnValues[1 .. $];
                    return ret;
                }
            }
        }
    }

    Mock m;
    m._impl = new Mock.MockImpl;
    static if (T.length > 0) {
        foreach (r; returns)
            m._impl._returnValues ~= r;
    }

    return m;
}

/**
   Version of mockStruct that accepts a compile-time mapping
   of function name to return values. Each template parameter
   must be a value of type `ReturnValues`
 */
auto mockStruct(T...)()
        if (T.length > 0 && from!"std.meta".allSatisfy!(isReturnValue, T)) {

    struct Mock {
        mixin MockImplCommon;

        // per function name: index of the next value to return
        int[string] _retIndices;

        /// Records the call and, when `funcName` has values registered,
        /// returns the next one.
        auto opDispatch(string funcName, V...)(auto ref V values) {

            import std.conv : to;
            import std.typecons : tuple;

            calledFuncs ~= funcName;
            calledValues ~= tuple(values).to!string;

            foreach (retVal; T) {
                static if (retVal.funcName == funcName) {
                    return retVal.values[_retIndices[funcName]++];
                }
            }
        }

        // NOTE(review): looks like leftover debugging code (hard-coded "greet"
        // key, always uses T[0]); kept so the struct's members are unchanged.
        // TODO confirm it can be removed.
        auto lefoofoo() {
            return T[0].values[_retIndices["greet"]++];
        }

    }

    Mock mock;

    foreach (retVal; T) {
        mock._retIndices[retVal.funcName] = 0;
    }

    return mock;
}

///
@("mock struct positive")
@safe pure unittest {
    void fun(T)(T t) {
        t.foobar;
    }

    auto m = mockStruct;
    m.expect!"foobar";
    fun(m);
    m.verify;
}

///
@("mock struct values positive")
@safe pure unittest {
    void fun(T)(T t) {
        t.foobar(2, "quux");
    }

    auto m = mockStruct;
    m.expect!"foobar"(2, "quux");
    fun(m);
    m.verify;
}

///
@("struct return value")
@safe pure unittest {

    int fun(T)(T f) {
        return f.timesN(3) * 2;
    }

    auto m = mockStruct(42, 12);
    assert(fun(m) == 84);
    assert(fun(m) == 24);
    assert(fun(m) == 0);
    m.expectCalled!"timesN";
}

///
@("struct expectCalled")
@safe pure unittest {
    void fun(T)(T t) {
        t.foobar(2, "quux");
    }

    auto m = mockStruct;
    fun(m);
    m.expectCalled!"foobar"(2, "quux");
}

///
@("mockStruct different return types for different functions")
@safe pure unittest {
    auto m = mockStruct!(ReturnValues!("length", 5), ReturnValues!("greet", "hello"));
    assert(m.length == 5);
    assert(m.greet("bar") == "hello");
    m.expectCalled!"length";
    m.expectCalled!"greet"("bar");
}

///
@("mockStruct different return types for different functions and multiple return values")
@safe pure unittest {
    auto m = mockStruct!(ReturnValues!("length", 5, 3), ReturnValues!("greet", "hello", "g'day"));
    assert(m.length == 5);
    m.expectCalled!"length";
    assert(m.length == 3);
    m.expectCalled!"length";

    assert(m.greet("bar") == "hello");
    m.expectCalled!"greet"("bar");
    assert(m.greet("quux") == "g'day");
    m.expectCalled!"greet"("quux");
}

/**
   A mock struct that always throws.

   Params:
     E = the exception type to throw; constructed from a message, file and line
     R = the declared return type of every function called on the mock
 */
auto throwStruct(E = from!"unit_threaded.should".UnitTestException, R = void)() {

    struct Mock {

        /// Throws `E` with a message naming the called function.
        R opDispatch(string funcName, string file = __FILE__, size_t line = __LINE__, V...)(
                auto ref V values) {
            throw new E(funcName ~ " was called", file, line);
        }
    }

    return Mock();
}

///
@("throwStruct default")
@safe pure unittest {
    import std.exception : assertThrown;
    import unit_threaded.should : UnitTestException;

    auto m = throwStruct;
    assertThrown!UnitTestException(m.foo);
    assertThrown!UnitTestException(m.bar(1, "foo"));
}